{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyspark\n",
    "from pyspark.sql.functions import udf\n",
    "from pyspark.sql.types import IntegerType, StringType, DoubleType\n",
    "from pyspark.ml import Transformer, Estimator, Pipeline\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "from mmlspark import ImageReader, ImageFeaturizer, UnrollImage, ImageTransformer, TrainClassifier, \\\n",
    "    SelectColumns, Repartition, ImageFeaturizer, ModelDownloader\n",
    "\n",
    "import numpy as np, pandas as pd, os, sys, time\n",
    "from os.path import join, abspath, exists\n",
    "from urllib.request import urlretrieve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "local",
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Download the CNTK model\n",
    "\n",
    "dataFile = \"flowers_and_labels.parquet\"\n",
    "dataZipFile = dataFile + \".zip\"\n",
    "cdnURL = \"https://mmlspark.azureedge.net/datasets\"\n",
    "dataURL = cdnURL + \"/Flowers/\" + dataZipFile\n",
    "localDataDir = \"/tmp/Flowers/\"\n",
    "localDataFile = join(localDataDir, dataFile)\n",
    "localDataZipFile = join(localDataDir, dataZipFile)\n",
    "modelName = \"ResNet50\"\n",
    "modelDir = \"file:\" + abspath(\"models\")\n",
    "\n",
    "def maybeDownload(url, path):\n",
    "    path = abspath(path)\n",
    "    if not os.path.isfile(path):\n",
    "        print(\"downloading to {}\".format(path))\n",
    "        urlretrieve(url, path)\n",
    "    else:\n",
    "        print(\"found {} skipping download\".format(abspath(path)))\n",
    "\n",
    "def maybeUnzip(zipFilePath, outputDir):\n",
    "    unzippedPath = join(outputDir, os.path.basename(zipFilePath).replace(\".zip\", \"\"))\n",
    "    if os.path.isfile(unzippedPath) or os.path.isdir(unzippedPath):\n",
    "        print(\"found {}, skipping unzipping\".format(unzippedPath))\n",
    "    else:\n",
    "        import zipfile\n",
    "        print(\"unzipping {}\".format(zipFilePath))\n",
    "        with zipfile.ZipFile(zipFilePath, \"r\") as zf:\n",
    "            zf.extractall(outputDir)\n",
    "\n",
    "os.makedirs(abspath(localDataDir), exist_ok=True)\n",
    "maybeDownload(dataURL, localDataZipFile)\n",
    "maybeUnzip(localDataZipFile, localDataDir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%local\n",
    "from os.path import join\n",
    "dataFile = \"flowers_and_labels.parquet\"\n",
    "dataZipFile = dataFile + \".zip\"\n",
    "cdnURL = \"https://mmlspark.azureedge.net/datasets\"\n",
    "dataURL = cdnURL + \"/Flowers/\" + dataZipFile\n",
    "localDataDir = \"/tmp/Flowers/\"\n",
    "localDataFile = join(localDataDir, dataFile)\n",
    "localDataZipFile = join(localDataDir, dataZipFile)\n",
    "\n",
    "import subprocess\n",
    "\n",
    "def runAndPrint(command):\n",
    "    \"\"\"Run a shell command and print its combined stdout/stderr.\"\"\"\n",
    "    output = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True)\n",
    "    # check_output returns bytes on Python 3 and str on Python 2\n",
    "    print(output if isinstance(output, str) else output.decode(\"utf-8\", \"replace\"))\n",
    "\n",
    "# Stage the dataset only when the `hdfs dfs -test -d` check exits non-zero\n",
    "if subprocess.call([\"hdfs\", \"dfs\", \"-test\", \"-d\", localDataDir]):\n",
    "    try:\n",
    "        from urllib.request import urlretrieve  # Python 3\n",
    "    except ImportError:\n",
    "        from urllib import urlretrieve  # Python 2\n",
    "    runAndPrint(\"mkdir -p {}\".format(localDataDir))\n",
    "    urlretrieve(dataURL, localDataZipFile)\n",
    "    runAndPrint(\"unzip {1} -d {0}\".format(localDataDir, localDataZipFile))\n",
    "    runAndPrint(\"hdfs dfs -mkdir -p {}\".format(localDataDir))\n",
    "    runAndPrint(\"hdfs dfs -copyFromLocal -f {0} {0}\".format(localDataFile))\n",
    "    # The extracted local copy is no longer needed once it is on HDFS\n",
    "    runAndPrint(\"rm -rf {}\".format(localDataFile))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "modelName = \"ResNet50\"\n",
    "modelDir = \"wasb:///models/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = ModelDownloader(spark, modelDir)\n",
    "model = d.downloadByName(modelName)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dataFile = \"flowers_and_labels.parquet\"\n",
    "localDataDir = \"/tmp/Flowers/\"\n",
    "localDataFile = join(localDataDir, dataFile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the images\n",
    "imagesWithLabels = spark.read.parquet(localDataFile)\n",
    "imagesWithLabels.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Smiley face](https://i.imgur.com/p2KgdYL.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline featurizer: resize each image to 60x60, then unroll it into a \"features\" column\n",
    "it = (\n",
    "    ImageTransformer()\n",
    "    .setOutputCol(\"scaled\")\n",
    "    .resize(height=60, width=60)\n",
    ")\n",
    "\n",
    "ur = UnrollImage().setInputCol(\"scaled\").setOutputCol(\"features\")\n",
    "\n",
    "basicFeaturizer = Pipeline(stages=[it, ur])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Featurizer backed by the downloaded CNTK model, reading \"image\" and writing \"features\"\n",
    "cntkFeaturizer = (\n",
    "    ImageFeaturizer()\n",
    "    .setInputCol(\"image\")\n",
    "    .setOutputCol(\"features\")\n",
    "    .setModelLocation(spark, model.uri)\n",
    "    .setLayerNames(model.layerNames)\n",
    "    .setCutOutputLayers(1)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Resnet 18](https://i.imgur.com/Mb4Dyou.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How does it work?\n",
    "\n",
    "![Convolutional network weights](http://i.stack.imgur.com/Hl2H6.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define some methods to help us experiment\n",
    "def featurize(featurizer, train, test, name):\n",
    "    \"\"\"Fit `featurizer` on `train` and apply it to both `train` and `test`.\n",
    "\n",
    "    The pipeline keeps only the \"features\" and \"labels\" columns and ends with\n",
    "    Repartition(n=4). Both outputs are cached; the image count, `name` and the\n",
    "    elapsed time are printed. Returns the pair (trainFeats, testFeats).\n",
    "    \"\"\"\n",
    "    startTime = time.time()\n",
    "    stages = [featurizer, SelectColumns(cols=[\"features\",\"labels\"]), Repartition(n=4)]\n",
    "    fitted = Pipeline(stages=stages).fit(train)\n",
    "    trainFeats, testFeats = (fitted.transform(df).cache() for df in (train, test))\n",
    "\n",
    "    # Count before reading the clock so the timing includes the count jobs\n",
    "    imageCount = trainFeats.count() + testFeats.count()\n",
    "    elapsed = time.time() - startTime\n",
    "    print(\"Featurized {} images with {} featurizer in {} seconds\".format(imageCount, name, elapsed))\n",
    "    sys.stdout.flush()\n",
    "    return trainFeats, testFeats\n",
    "\n",
    "def predict(model, train, test, name):\n",
    "    \"\"\"Fit `model` on `train` and return the cached predictions for `test`.\n",
    "\n",
    "    Only the \"scored_labels\" and \"labels\" columns are kept. The number of\n",
    "    classified rows, `name` and the elapsed time are printed.\n",
    "    \"\"\"\n",
    "    startTime = time.time()\n",
    "    keepColumns = SelectColumns(cols=([\"scored_labels\",\"labels\"]))\n",
    "    fitted = Pipeline(stages=[model, keepColumns]).fit(train)\n",
    "    predictions = fitted.transform(test).cache()\n",
    "\n",
    "    # Count before reading the clock so the timing includes the count job\n",
    "    classifiedCount = predictions.count()\n",
    "    elapsed = time.time() - startTime\n",
    "    print(\"Classified {} images from {} features in {} seconds\".format(classifiedCount, name, elapsed))\n",
    "    sys.stdout.flush()\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Increase or remove the sampling in order to get better results\n",
    "# A fixed seed makes the sample and the train/test split repeatable between runs\n",
    "splitSeed = 42\n",
    "train, test = imagesWithLabels.sample(False, 0.03, seed=splitSeed).randomSplit([.8,.2], seed=splitSeed)\n",
    "train, test = train.repartition(1), test.repartition(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train.cache()\n",
    "test.cache()\n",
    "train.count(), test.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named `classifier` so the downloaded CNTK `model` read by cntkFeaturizer is not overwritten\n",
    "classifier = TrainClassifier().setModel(LogisticRegression()).setLabelCol(\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainFeatsBasic, testFeatsBasic = featurize(basicFeaturizer, train, test, \"basic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "basicPredictions = predict(classifier, trainFeatsBasic, testFeatsBasic, \"basic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainFeatsCNTK, testFeatsCNTK = featurize(cntkFeaturizer, train, test, \"cntk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cntkPredictions = predict(classifier, trainFeatsCNTK, testFeatsCNTK, \"cntk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "local",
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert the Spark results to pandas for plotting; the guards keep the cell safe to re-run\n",
    "if hasattr(basicPredictions, \"toPandas\"):\n",
    "    basicPredictions = basicPredictions.toPandas()\n",
    "if hasattr(cntkPredictions, \"toPandas\"):\n",
    "    cntkPredictions = cntkPredictions.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# registerTempTable is deprecated; createOrReplaceTempView is its replacement\n",
    "basicPredictions.createOrReplaceTempView(\"basicPredictions\")\n",
    "cntkPredictions.createOrReplaceTempView(\"cntkPredictions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%sql -q -o basicPredictions\n",
    "select * from basicPredictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%sql -q -o cntkPredictions\n",
    "select * from cntkPredictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot confusion matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "def evaluate(results, name):\n",
    "    \"\"\"Plot the row-normalized confusion matrix of `results` on the current axes.\n",
    "\n",
    "    `results` must provide \"labels\" (ground truth) and \"scored_labels\"\n",
    "    (predictions); `name` and the accuracy are shown in the plot title.\n",
    "    \"\"\"\n",
    "    y, y_hat = results[\"labels\"],results[\"scored_labels\"]\n",
    "    y = [int(l) for l in y]\n",
    "\n",
    "    accuracy = np.mean([1. if pred==true else 0. for (pred,true) in zip(y_hat,y)])\n",
    "    cm = confusion_matrix(y, y_hat)\n",
    "    # A label that occurs only among the predictions has an all-zero row;\n",
    "    # dividing that row by 1 keeps it at zero instead of producing NaNs\n",
    "    rowTotals = cm.sum(axis=1)[:, np.newaxis]\n",
    "    rowTotals[rowTotals == 0] = 1\n",
    "    cm = cm.astype(\"float\") / rowTotals\n",
    "\n",
    "    plt.imshow(cm, interpolation=\"nearest\", cmap=plt.cm.Blues)\n",
    "    plt.colorbar()\n",
    "    plt.xlabel(\"$Predicted$ $label$\", fontsize=18)\n",
    "    plt.ylabel(\"$True$ $Label$\", fontsize=18)\n",
    "    # The accuracy goes into the title: a text at data coordinates (40, 10) lies far\n",
    "    # outside a small confusion matrix. Raw string because \"\\%\" is not a valid escape.\n",
    "    accuracyText = r\"$Accuracy$ $=$ ${}\\%$\".format(round(accuracy*100,1))\n",
    "    plt.title(\"$Normalized$ $CM$ $for$ ${}$\".format(name) + \"\\n\" + accuracyText)\n",
    "\n",
    "plt.figure(figsize=(12,5))\n",
    "plt.subplot(1,2,1)\n",
    "evaluate(cntkPredictions,\"CNTKModel + LR\")\n",
    "plt.subplot(1,2,2)\n",
    "evaluate(basicPredictions,\"LR\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
